{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class PairDistanceLoss(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(PairDistanceLoss, self).__init__()\n",
    "        \n",
    "    def forward(self, x, y):\n",
    "        # x：预测分数(N, Class)\n",
    "        # y: 实际标签(N, Class)\n",
    "        N = x.size(0)\n",
    "        C = x.size(1)\n",
    "        loss = 0.0\n",
    "        for i in range(N):\n",
    "            ni = torch.sum(y[i]).int().data[0]\n",
    "            true_xi, false_xi = x[i][y[i]==0], x[i][y[i]==1]\n",
    "            true_vec, false_vec = true_xi.expand(ni, C-ni), false_xi.expand(C-ni, ni).t()\n",
    "            loss += 1./(ni*(C-ni))*torch.sum(torch.exp(-(true_vec-false_vec)))\n",
    "        return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Variable containing:\n",
       " 3.0862\n",
       "[torch.FloatTensor of size 1]"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from torch.autograd import Variable\n",
    "\n",
    "# Evaluate the loss on a toy batch: 2 samples, 3 classes.\n",
    "loss = PairDistanceLoss()\n",
    "x = Variable(torch.Tensor([[1, 2, 3], [4, 5, 6]]))  # predicted scores\n",
    "y = Variable(torch.Tensor([[0, 1, 0], [1, 0, 1]]))  # multi-hot labels\n",
    "loss(x, y)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
